{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 建立神经网络\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/model.ipynb)&emsp;[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/quick_start/mindspore_model.ipynb)&emsp;[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL3F1aWNrX3N0YXJ0L21pbmRzcG9yZV9tb2RlbC5pcHluYg==&imagename=MindSpore1.1.1)\n",
    "\n",
    "神经网络模型由多个数据操作层组成，`mindspore.nn`提供了各种网络基础模块。\n",
    "\n",
    "在以下内容中，我们将以构建LeNet网络为例，展示MindSpore是如何建立神经网络模型的。\n",
    "\n",
    "首先导入本文档需要的模块和接口，如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型类\n",
    "\n",
    "MindSpore的`Cell`类是构建所有网络的基类，也是网络的基本单元。当用户需要神经网络时，需要继承`Cell`类，并重写`__init__`方法和`construct`方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"\n",
    "    Lenet网络结构\n",
    "    \"\"\"\n",
    "    def __init__(self, num_class=10, num_channel=1):\n",
    "        super(LeNet5, self).__init__()\n",
    "        # 定义所需要的运算\n",
    "        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n",
    "        self.fc1 = nn.Dense(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Dense(120, 84)\n",
    "        self.fc3 = nn.Dense(84, num_class)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    def construct(self, x):\n",
    "        # 使用定义好的运算构建前向网络\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型层\n",
    "\n",
    "本小节内容首先将会介绍LeNet网络中使用到`Cell`类的关键成员函数，然后通过实例化网络介绍如何利用`Cell`类访问模型参数。\n",
    "\n",
    "### nn.Conv2d\n",
    "\n",
    "加入`nn.Conv2d`层，给网络中加入卷积函数，帮助神经网络提取特征。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 6, 28, 28)\n"
     ]
    }
   ],
   "source": [
    "conv2d = nn.Conv2d(1, 6, 5, has_bias=False, weight_init='normal', pad_mode='valid')\n",
    "input_x = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)\n",
    "\n",
    "print(conv2d(input_x).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### nn.ReLU\n",
    "\n",
    "加入`nn.ReLU`层，给网络中加入非线性的激活函数，帮助神经网络学习各种复杂的特征。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0. 2. 0. 2. 0.]\n"
     ]
    }
   ],
   "source": [
    "relu = nn.ReLU()\n",
    "input_x = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)\n",
    "output = relu(input_x)\n",
    "\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### nn.MaxPool2d\n",
    "\n",
    "初始化`nn.MaxPool2d`层，将6×28×28的数组降采样为6×14×14的数组。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 6, 14, 14)\n"
     ]
    }
   ],
   "source": [
    "max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "input_x = Tensor(np.ones([1, 6, 28, 28]), mindspore.float32)\n",
    "\n",
    "print(max_pool2d(input_x).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### nn.Flatten\n",
    "\n",
    "初始化`nn.Flatten`层，将16×5×5的数组转换为400个连续数组。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 400)\n"
     ]
    }
   ],
   "source": [
    "flatten = nn.Flatten()\n",
    "input_x = Tensor(np.ones([1, 16, 5, 5]), mindspore.float32)\n",
    "output = flatten(input_x)\n",
    "\n",
    "print(output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### nn.Dense\n",
    "\n",
    "初始化`nn.Dense`层，对输入矩阵进行线性变换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 120)\n"
     ]
    }
   ],
   "source": [
    "dense = nn.Dense(400, 120, weight_init='normal')\n",
    "input_x = Tensor(np.ones([1, 400]), mindspore.float32)\n",
    "output = dense(input_x)\n",
    "\n",
    "print(output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型参数\n",
    "\n",
    "网络内部的卷积层和全连接层等实例化后，即具有权重和偏置，这些权重和偏置参数会在之后训练中进行优化。`nn.Cell`中使用`parameters_and_names()`方法访问所有参数。\n",
    "\n",
    "在示例中，我们遍历每个参数，并打印网络各层名字和属性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('conv1.weight', Parameter (name=conv1.weight))\n",
      "('conv2.weight', Parameter (name=conv2.weight))\n",
      "('fc1.weight', Parameter (name=fc1.weight))\n",
      "('fc1.bias', Parameter (name=fc1.bias))\n",
      "('fc2.weight', Parameter (name=fc2.weight))\n",
      "('fc2.bias', Parameter (name=fc2.bias))\n",
      "('fc3.weight', Parameter (name=fc3.weight))\n",
      "('fc3.bias', Parameter (name=fc3.bias))\n"
     ]
    }
   ],
   "source": [
    "model = LeNet5()\n",
    "for m in model.parameters_and_names():\n",
    "    print(m)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
